"""
Phase 6: Projections — Safe Asset Gap & Rating Migration
=========================================================
Forward-projects demographic demand using UN WPP 2024 medium variant
(already in full_panel.csv for year > 2024). Projects safe supply from
fiscal trajectories. Constructs "safe asset gap" and rating migration risk.

Output: table6_projections.md, table6b_rating_migration.md, figure1_safe_gap.png
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore')

# Root of this sub-project; the sibling projects live next to it under the same parent.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/safe_assets")
MULTILATERAL_DIR = PROJECT_DIR.parent.joinpath("multilateral")
FISCAL_DIR = PROJECT_DIR.parent.joinpath("fiscal_dominance")

# Put the multilateral project's `src` directory first on the import path.
sys.path.insert(0, str(MULTILATERAL_DIR.joinpath("src")))

# Input/output locations used by this script.
PROCESSED_DIR = PROJECT_DIR.joinpath("data", "processed")
TABLES_DIR = PROJECT_DIR.joinpath("output", "tables")
FIGURES_DIR = PROJECT_DIR.joinpath("output", "figures")

# Create the output directories up front (no error if they already exist).
for d in (TABLES_DIR, FIGURES_DIR):
    d.mkdir(parents=True, exist_ok=True)

# ISO3 codes treated as "safe" issuers in 2024 (rating history comes from phase 1).
# GBR and FRA are deliberately absent; the original note attributes this to
# downgrades below AA-.
# NOTE(review): AT_RISK below records GBR and FRA at 'AA', which is not below
# AA- — confirm which of the two statements is current.
SAFE_ISSUERS_2024 = (
    "USA DEU CHE NLD DNK NOR SWE SGP LUX "
    "CAN AUS AUT FIN HKG TWN NZL"
).split()

# Issuers monitored for downgrade risk: ISO3 -> current rating and the main
# source of pressure. The ratings recorded here span AA- to AA+.
# NOTE(review): NZL, FIN and AUT also appear in SAFE_ISSUERS_2024 — presumably
# intentional (safe today, at risk going forward); confirm.
AT_RISK = {
    iso3: {'current_rating': rating, 'risk': risk}
    for iso3, rating, risk in (
        ('GBR', 'AA', 'fiscal_pressure'),
        ('FRA', 'AA', 'fiscal_pressure'),
        ('KOR', 'AA', 'demographic_pressure'),
        ('BEL', 'AA-', 'debt_trajectory'),
        ('NZL', 'AA+', 'small_economy'),
        ('FIN', 'AA+', 'demographic_pressure'),
        ('AUT', 'AA+', 'eurozone_spillover'),
    )
}


_BASE_YEAR = 2024        # last historical year; both indices are normalised to 100 here
_LAST_PROJ_YEAR = 2050   # final year of the linear debt/GDP extrapolation


def _first_value(frame, year, column):
    """Return the first `column` value in `frame` whose 'year' equals `year`, else NaN."""
    if column not in frame.columns:
        return np.nan
    vals = frame.loc[frame['year'] == year, column].values
    return vals[0] if len(vals) > 0 else np.nan


def _weighted_oadr(frame, weight_col):
    """Return (weighted mean of 'old_dep', rows used), or None if nothing is usable."""
    valid = frame.dropna(subset=['old_dep', weight_col])
    if len(valid) == 0:
        return None
    weights = valid[weight_col].values
    if weights.sum() == 0:
        # np.average raises ZeroDivisionError when the weights sum to zero.
        return None
    return np.average(valid['old_dep'].values, weights=weights), len(valid)


def _project_demand(hist, proj):
    """Build the yearly GDP-weighted global OADR series (historical + projected).

    Projected years are weighted by each country's 2024 GDP; historical years
    (1990-2024) are weighted by that year's own GDP.
    """
    weights_2024 = (hist.loc[hist['year'] == _BASE_YEAR, ['iso3', 'ngdp_usd']]
                    .dropna()
                    .rename(columns={'ngdp_usd': 'gdp_weight'}))

    records = []
    for yr in sorted(proj['year'].unique()):
        # Only the columns actually needed are selected, so an absent 'Z_1'
        # column can no longer abort the demand projection.
        yr_data = proj.loc[proj['year'] == yr, ['iso3', 'old_dep']].merge(
            weights_2024, on='iso3', how='inner')
        result = _weighted_oadr(yr_data, 'gdp_weight')
        if result is not None:
            records.append({'year': int(yr), 'global_oadr_proj': result[0],
                            'n_countries': result[1]})

    for yr in range(1990, _BASE_YEAR + 1):
        result = _weighted_oadr(hist[hist['year'] == yr], 'ngdp_usd')
        if result is not None:
            records.append({'year': yr, 'global_oadr_proj': result[0],
                            'n_countries': result[1]})

    if not records:
        # Typed empty frame so downstream sort/merge on 'year' still works.
        return pd.DataFrame({'year': pd.Series(dtype='int64'),
                             'global_oadr_proj': pd.Series(dtype='float64'),
                             'n_countries': pd.Series(dtype='int64')})
    return pd.DataFrame(records).sort_values('year')


def _debt_history(frame, iso3):
    """Return `iso3`'s non-missing debt/GDP rows up to the base year, sorted by year."""
    if 'govt_debt_gdp' not in frame.columns:
        return frame.iloc[0:0]
    mask = (frame['iso3'] == iso3) & (frame['year'] <= _BASE_YEAR)
    return frame[mask].dropna(subset=['govt_debt_gdp']).sort_values('year')


def _project_supply(hist, fisc):
    """Build the yearly safe-supply ratio (safe issuers' debt / their GDP).

    For every issuer in SAFE_ISSUERS_2024 the debt/GDP ratio is extrapolated
    linearly to 2050 and converted to USD with the issuer's 2024 GDP;
    historical years use that year's GDP.
    """
    records = []
    for iso3 in SAFE_ISSUERS_2024:
        debt_hist = _debt_history(fisc, iso3)
        if len(debt_hist) < 5:
            debt_hist = _debt_history(hist, iso3)
        if len(debt_hist) < 5:
            continue

        # Slope over the last ten years (2014 onwards); fall back to the last
        # five observations when that window is too thin.
        recent = debt_hist[debt_hist['year'] >= 2014]
        if len(recent) < 3:
            recent = debt_hist.tail(5)
        # Divide by elapsed years rather than by the number of observations so
        # that missing years do not inflate the annual change.
        year_span = recent['year'].iloc[-1] - recent['year'].iloc[0]
        if year_span > 0:
            annual_change = ((recent['govt_debt_gdp'].iloc[-1] -
                              recent['govt_debt_gdp'].iloc[0]) / year_span)
        else:
            annual_change = 0.0

        # Extrapolate from the last actual observation, which may predate 2024.
        anchor_year = debt_hist['year'].iloc[-1]
        anchor_debt = debt_hist['govt_debt_gdp'].iloc[-1]

        # Per-issuer GDP lookup (first match per year) instead of re-filtering
        # the whole panel for every row.
        iso_hist = hist[hist['iso3'] == iso3]
        gdp_by_year = {}
        for yr, gdp in zip(iso_hist['year'], iso_hist['ngdp_usd']):
            gdp_by_year.setdefault(int(yr), gdp)
        gdp_2024 = gdp_by_year.get(_BASE_YEAR, np.nan)

        for yr in range(_BASE_YEAR + 1, _LAST_PROJ_YEAR + 1):
            proj_debt = max(0, anchor_debt + annual_change * (yr - anchor_year))  # floor at 0
            records.append({
                'iso3': iso3, 'year': yr,
                'govt_debt_gdp_proj': proj_debt,
                'ngdp_usd_2024': gdp_2024,
            })

        # Historical observations; despite its name the GDP column holds the
        # same-year GDP for these rows.
        for yr, debt in zip(debt_hist['year'], debt_hist['govt_debt_gdp']):
            records.append({
                'iso3': iso3, 'year': int(yr),
                'govt_debt_gdp_proj': debt,
                'ngdp_usd_2024': gdp_by_year.get(int(yr), np.nan),
            })

    if not records:
        return pd.DataFrame({'year': pd.Series(dtype='int64'),
                             'total_safe_debt': pd.Series(dtype='float64'),
                             'total_gdp': pd.Series(dtype='float64'),
                             'n_issuers': pd.Series(dtype='int64'),
                             'safe_supply_ratio_proj': pd.Series(dtype='float64')})

    supply_df = pd.DataFrame(records)
    supply_df['safe_debt_usd'] = (supply_df['govt_debt_gdp_proj'] / 100 *
                                   supply_df['ngdp_usd_2024'])
    yearly_supply = supply_df.groupby('year').agg(
        total_safe_debt=('safe_debt_usd', 'sum'),
        total_gdp=('ngdp_usd_2024', 'sum'),
        n_issuers=('iso3', 'nunique')
    ).reset_index()
    yearly_supply['safe_supply_ratio_proj'] = (yearly_supply['total_safe_debt'] /
                                                yearly_supply['total_gdp'])
    return yearly_supply


def _build_gap(demand_df, yearly_supply):
    """Merge demand and supply by year and add 2024=100 indices plus their gap."""
    gap_df = demand_df.merge(
        yearly_supply[['year', 'safe_supply_ratio_proj', 'n_issuers']],
        on='year', how='outer').sort_values('year')

    base_demand = _first_value(gap_df, _BASE_YEAR, 'global_oadr_proj')
    base_supply = _first_value(gap_df, _BASE_YEAR, 'safe_supply_ratio_proj')

    # A missing or zero base would yield all-NaN or infinite indices; leave the
    # index columns out instead so downstream code can detect their absence.
    if (pd.notna(base_demand) and pd.notna(base_supply)
            and base_demand != 0 and base_supply != 0):
        gap_df['demand_index'] = gap_df['global_oadr_proj'] / base_demand * 100
        gap_df['supply_index'] = gap_df['safe_supply_ratio_proj'] / base_supply * 100
        gap_df['safe_asset_gap'] = gap_df['demand_index'] - gap_df['supply_index']
    else:
        print(f"  WARNING: no usable {_BASE_YEAR} base for demand and/or supply — "
              "indices and gap not computed")
    return gap_df


def _assess_migration(fp, fisc):
    """Collect OADR (2024/2040/2050) and latest debt/GDP for each AT_RISK country."""
    has_debt = 'govt_debt_gdp' in fisc.columns
    records = []
    for iso3, info in AT_RISK.items():
        c_proj = fp[fp['iso3'] == iso3]

        debt_2024 = np.nan
        if has_debt:
            c_fisc = fisc[(fisc['iso3'] == iso3) &
                          (fisc['year'] <= _BASE_YEAR)].sort_values('year')
            debt_obs = c_fisc['govt_debt_gdp'].dropna()
            if len(debt_obs) > 0:
                # Latest available observation, not necessarily 2024 itself.
                debt_2024 = debt_obs.values[-1]

        records.append({
            'iso3': iso3,
            'current_rating': info['current_rating'],
            'risk_type': info['risk'],
            'oadr_2024': _first_value(c_proj, 2024, 'old_dep'),
            'oadr_2040': _first_value(c_proj, 2040, 'old_dep'),
            'oadr_2050': _first_value(c_proj, 2050, 'old_dep'),
            'debt_gdp_2024': debt_2024,
        })
    return pd.DataFrame(records)


def _issuer_trajectories(fp):
    """Collect 'old_dep' and 'Z_1' at 2024/2030/2040/2050 for each safe issuer."""
    records = []
    for iso3 in SAFE_ISSUERS_2024:
        c_data = fp[fp['iso3'] == iso3]
        for yr in (2024, 2030, 2040, 2050):
            if (c_data['year'] == yr).any():
                records.append({
                    'iso3': iso3, 'year': yr,
                    'old_dep': _first_value(c_data, yr, 'old_dep'),
                    'Z_1': _first_value(c_data, yr, 'Z_1'),
                })
    return pd.DataFrame(records)


def _plot_gap(gap_df):
    """Save figure1_safe_gap.png to FIGURES_DIR; return its path, or None if skipped."""
    try:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt
    except ImportError:
        print("  matplotlib not available — skipping figure")
        return None

    if 'demand_index' not in gap_df.columns:
        # Avoid writing a blank image when the indices could not be computed.
        print("  Index columns unavailable — skipping figure")
        return None

    plot_data = gap_df[(gap_df['year'] >= 2000) &
                       (gap_df['year'] <= 2050)].copy()

    fig, ax1 = plt.subplots(figsize=(10, 6))
    try:
        ax1.plot(plot_data['year'], plot_data['demand_index'],
                 'b-', linewidth=2, label='Demographic Demand (OADR index)')
        ax1.plot(plot_data['year'], plot_data['supply_index'],
                 'r--', linewidth=2, label='Safe Supply (debt ratio index)')
        ax1.axvline(x=2024, color='gray', linestyle=':', alpha=0.5)
        ax1.text(2024.5, ax1.get_ylim()[1] * 0.95, 'Projection →',
                 fontsize=9, color='gray')
        ax1.set_xlabel('Year')
        ax1.set_ylabel('Index (2024 = 100)')
        ax1.set_title('Safe Asset Gap: Demographic Demand vs Safe Supply')
        ax1.legend(loc='upper left')
        ax1.grid(True, alpha=0.3)

        # Secondary axis: shaded gap between the two indices.
        if 'safe_asset_gap' in plot_data.columns:
            ax2 = ax1.twinx()
            ax2.fill_between(plot_data['year'], 0, plot_data['safe_asset_gap'],
                             alpha=0.15, color='purple')
            ax2.set_ylabel('Gap (demand - supply)', color='purple')
            ax2.tick_params(axis='y', labelcolor='purple')

        fig_path = FIGURES_DIR / "figure1_safe_gap.png"
        fig.tight_layout()
        fig.savefig(fig_path, dpi=150, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print(f"  Saved: {fig_path}")
    return fig_path


def main():
    """Run Phase 6: build the safe-asset gap and rating-migration outputs.

    Reads full_panel.csv from the multilateral project (rows with year > 2024
    are treated as projections) and, if present, fiscal_panel.csv from the
    fiscal_dominance project. Writes figure1_safe_gap.png to FIGURES_DIR and
    the two markdown tables via build_projections_table and
    build_migration_table.

    Raises:
        KeyError: if full_panel.csv lacks a column this phase requires.
    """
    print("=" * 70)
    print("PHASE 6: Projections — Safe Asset Gap & Rating Migration")
    print("=" * 70)

    # ── [1] Load full panel including projections (year > 2024) ──
    print("\n[1] Loading full panel with projections ...")
    fp = pd.read_csv(MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv")
    missing = [c for c in ('iso3', 'year', 'old_dep', 'ngdp_usd') if c not in fp.columns]
    if missing:
        raise KeyError(f"full_panel.csv is missing required column(s): {missing}")
    print(f"  Full panel: {fp['iso3'].nunique()} countries, {len(fp):,} obs, "
          f"{fp['year'].min()}-{fp['year'].max()}")

    # Separate historical and projected
    hist = fp[fp['year'] <= _BASE_YEAR].copy()
    proj = fp[fp['year'] > _BASE_YEAR].copy()
    print(f"  Historical: {len(hist):,} obs ({hist['year'].min()}-{hist['year'].max()})")
    print(f"  Projected: {len(proj):,} obs ({proj['year'].min()}-{proj['year'].max()})")

    # ── [2] Project demographic demand ──
    print("\n[2] Projecting demographic demand ...")
    demand_df = _project_demand(hist, proj)
    if len(demand_df) > 0:
        print(f"  Demand projection: {demand_df['year'].min()}-{demand_df['year'].max()}")
        for yr in (_BASE_YEAR, demand_df['year'].max()):
            value = _first_value(demand_df, yr, 'global_oadr_proj')
            if pd.notna(value):
                print(f"  Global OADR {yr}: {value:.2f}")
    else:
        print("  WARNING: no usable OADR/GDP observations — demand series is empty")

    # ── [3] Project safe asset supply ──
    print("\n[3] Projecting safe asset supply ...")
    fisc_path = FISCAL_DIR / "data" / "processed" / "fiscal_panel.csv"
    if fisc_path.exists():
        fisc = pd.read_csv(fisc_path)
    else:
        # No fiscal panel available: fall back to the historical main panel.
        fisc = hist
    yearly_supply = _project_supply(hist, fisc)
    if len(yearly_supply) > 0:
        print(f"  Supply projection: {yearly_supply['year'].min()}-{yearly_supply['year'].max()}")
    else:
        print("  WARNING: no safe issuer has enough debt/GDP history — supply series is empty")

    # ── [4] Construct safe asset gap ──
    print("\n[4] Constructing safe asset gap ...")
    gap_df = _build_gap(demand_df, yearly_supply)

    # ── [5] Rating migration risk ──
    print("\n[5] Assessing rating migration risk ...")
    migration_df = _assess_migration(fp, fisc)

    # ── [6] Country-level projections for current safe issuers ──
    print("\n[6] Country-level demographic projections for safe issuers ...")
    issuer_proj_df = _issuer_trajectories(fp)

    # ── [7] Generate figure ──
    print("\n[7] Generating safe asset gap figure ...")
    _plot_gap(gap_df)

    # ── [8] Build tables ──
    print("\n[8] Building output tables ...")
    build_projections_table(gap_df, issuer_proj_df)
    build_migration_table(migration_df)

    print("\n" + "=" * 70)
    print("Phase 6 complete.")
    print("=" * 70)


def build_projections_table(gap_df, issuer_proj_df, out_dir=None):
    """Write Table 6 (forward projections) as markdown.

    Args:
        gap_df: Yearly frame with a 'year' column. The columns
            'global_oadr_proj', 'safe_supply_ratio_proj', 'demand_index',
            'supply_index' and 'safe_asset_gap' are optional; missing or NaN
            values are rendered as '—'.
        issuer_proj_df: Per-issuer frame with 'iso3', 'year' and 'old_dep'.
            May be empty, in which case the country section is omitted.
        out_dir: Directory to write into. Defaults to the module-level
            TABLES_DIR when None.

    Returns:
        Path of the written file (``table6_projections.md``).
    """
    def cell(value, spec):
        # Missing columns arrive here as None (via Series.get), missing data as NaN.
        return format(value, spec) if pd.notna(value) else '—'

    md = ["# Table 6: Safe Asset Gap Projections\n"]

    # Global trajectory
    md.append("## Global Demand-Supply Trajectory\n")
    md.append("| Year | Global OADR | Safe Supply Ratio | Demand Index | Supply Index | Gap |")
    md.append("|------|------------|-------------------|-------------|-------------|-----|")
    for yr in [2000, 2010, 2020, 2024, 2030, 2035, 2040, 2045, 2050]:
        row = gap_df[gap_df['year'] == yr]
        if len(row) == 0:
            continue
        r = row.iloc[0]
        oadr = cell(r.get('global_oadr_proj'), '.2f')
        supply = cell(r.get('safe_supply_ratio_proj'), '.4f')
        d_idx = cell(r.get('demand_index'), '.1f')
        s_idx = cell(r.get('supply_index'), '.1f')
        gap = cell(r.get('safe_asset_gap'), '.1f')
        md.append(f"| {yr} | {oadr} | {supply} | {d_idx} | {s_idx} | {gap} |")

    # Country-level projections
    if len(issuer_proj_df) > 0:
        md.append("\n## Safe Issuer Demographic Trajectories\n")
        md.append("| Country | OADR 2024 | OADR 2030 | OADR 2040 | OADR 2050 |")
        md.append("|---------|-----------|-----------|-----------|-----------|")

        for iso3 in sorted(issuer_proj_df['iso3'].unique()):
            c = issuer_proj_df[issuer_proj_df['iso3'] == iso3]
            vals = {}
            for yr in [2024, 2030, 2040, 2050]:
                matches = c.loc[c['year'] == yr, 'old_dep'].values
                vals[yr] = cell(matches[0], '.1f') if len(matches) > 0 else '—'
            md.append(f"| {iso3} | {vals[2024]} | {vals[2030]} | {vals[2040]} | {vals[2050]} |")

    md.append("\n*Demand: GDP-weighted global OADR. Supply: debt/GDP of AA- or above issuers.*")
    md.append("*Index: 2024 = 100. Gap = demand index - supply index.*")
    md.append("*Demographic projections: UN WPP 2024 medium variant.*")

    target_dir = TABLES_DIR if out_dir is None else Path(out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    out = target_dir / "table6_projections.md"
    # The table contains the non-ASCII '—' placeholder, so pin the encoding
    # instead of relying on the platform default.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")
    return out


def build_migration_table(migration_df, out_dir=None):
    """Write Table 6b (rating migration risk) as markdown.

    Args:
        migration_df: One row per near-safe issuer with the columns 'iso3',
            'current_rating', 'risk_type', 'oadr_2024', 'oadr_2040',
            'oadr_2050' and 'debt_gdp_2024'. NaN values are rendered as '—'.
        out_dir: Directory to write into. Defaults to the module-level
            TABLES_DIR when None.

    Returns:
        Path of the written file (``table6b_rating_migration.md``).
    """
    def cell(value):
        return f"{value:.1f}" if pd.notna(value) else '—'

    md = ["# Table 6b: Rating Migration Risk for Near-Safe Issuers\n"]

    md.append("| Country | Current Rating | Risk Type | OADR 2024 | OADR 2040 | OADR 2050 | Debt/GDP 2024 |")
    md.append("|---------|---------------|-----------|-----------|-----------|-----------|---------------|")

    for _, r in migration_df.iterrows():
        md.append(f"| {r['iso3']} | {r['current_rating']} | {r['risk_type']} "
                  f"| {cell(r['oadr_2024'])} | {cell(r['oadr_2040'])} "
                  f"| {cell(r['oadr_2050'])} | {cell(r['debt_gdp_2024'])} |")

    md.append("\n*Rating migration risk based on fiscal trajectory and demographic pressure.*")
    md.append("*From fiscal dominance paper: +10pp OADR → +12pp expenditure, +5pp revenue.*")
    # The at-risk set defined in this file includes an AA- issuer (BEL), so the
    # footnote states the full AA- to AA+ range.
    md.append("*Countries currently at AA- to AA+ that face downgrade pressure from aging.*")

    target_dir = TABLES_DIR if out_dir is None else Path(out_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    out = target_dir / "table6b_rating_migration.md"
    # The table contains non-ASCII characters ('—', '→'), so pin the encoding
    # instead of relying on the platform default.
    out.write_text('\n'.join(md), encoding='utf-8')
    print(f"  Saved: {out}")
    return out


# Run the Phase 6 pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
